Ensemble Methods (Random Forest) Example

This is a simple example of an ensemble method, the Random Forest, implemented in Python with scikit-learn.

Ensemble Methods Overview

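Ensemble methods combine the predictions of multiple base models to achieve better accuracy and robustness than any single model. One popular ensemble method is the Random Forest, which builds a collection of decision trees and combines their predictions: a majority vote for classification and an average for regression. Each tree is trained on a bootstrap sample of the training data (rows drawn at random with replacement), and each split inside a tree considers only a random subset of the features. This randomness makes the trees diverse, so their individual errors partly cancel out and the combined model is more accurate and more stable than any one tree.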
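The bootstrap-and-vote idea described above can be sketched from scratch with scikit-learn's DecisionTreeClassifier. This is a minimal illustration, not how RandomForestClassifier is implemented internally. The helper names fit_bagged_trees and predict_majority are hypothetical, and the 25-tree count is arbitrary. The sketch also uses a hard majority vote, whereas scikit-learn's RandomForestClassifier averages the trees' predicted class probabilities.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def fit_bagged_trees(X, y, n_trees=25, seed=0):
    """Hypothetical helper: train n_trees trees, each on a bootstrap sample of (X, y)."""
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    trees = []
    for _ in range(n_trees):
        # Bootstrap sample: n_samples row indices drawn with replacement.
        idx = rng.integers(0, n_samples, size=n_samples)
        # max_features="sqrt": each split considers a random subset of the features.
        tree = DecisionTreeClassifier(max_features="sqrt",
                                      random_state=int(rng.integers(1_000_000)))
        tree.fit(X[idx], y[idx])
        trees.append(tree)
    return trees


def predict_majority(trees, X):
    """Hypothetical helper: hard majority vote (assumes integer labels 0..k-1)."""
    votes = np.stack([tree.predict(X) for tree in trees])  # shape (n_trees, n_samples)
    # For each sample (column), pick the most frequent predicted label.
    return np.apply_along_axis(lambda col: np.bincount(col.astype(int)).argmax(), 0, votes)


X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_clusters_per_class=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

trees = fit_bagged_trees(X_train, y_train)
y_pred = predict_majority(trees, X_test)
print("Bagged-tree accuracy:", np.mean(y_pred == y_test))
```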

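Key concepts of Random Forest and ensemble methods:

- Bagging (bootstrap aggregating): each tree is trained on a different bootstrap sample of the training data.
- Feature randomness: each split considers only a random subset of the features, which further decorrelates the trees.
- Aggregation: the forest predicts by majority vote (or averaged class probabilities) for classification and by averaging for regression.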

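Random Forests are widely used for classification and regression because they are much less prone to overfitting than a single decision tree, capture nonlinear relationships and feature interactions, and need little preprocessing (for example, no feature scaling).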
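For regression, the same approach applies through RandomForestRegressor, which predicts the average of its trees' outputs. The sketch below assumes synthetic data from make_regression (the sample count, feature counts, and noise level are arbitrary choices for illustration) and checks that averaging the fitted trees in estimators_ reproduces the forest's prediction.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic regression data (assumed for illustration, not part of the original example).
X, y = make_regression(n_samples=1000, n_features=20, n_informative=10,
                       noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# The forest's regression prediction is the mean of its trees' predictions.
regressor = RandomForestRegressor(n_estimators=100, random_state=42)
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)

print(f"MAE: {mean_absolute_error(y_test, y_pred):.2f}")
print(f"R^2: {r2_score(y_test, y_pred):.3f}")

# Averaging the individual fitted trees should reproduce the forest's output.
tree_mean = np.mean([tree.predict(X_test) for tree in regressor.estimators_], axis=0)
print("Matches forest average:", np.allclose(tree_mean, y_pred))
```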

Python Source Code:

# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split  # needed for the train/test split below

# Generate synthetic classification data
np.random.seed(42)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10, n_clusters_per_class=2, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build a Random Forest model
random_forest = RandomForestClassifier(n_estimators=100, random_state=42)
random_forest.fit(X_train, y_train)

# Make predictions on the test set
y_pred = random_forest.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)

print(f'Accuracy: {accuracy}')
print(f'Confusion Matrix:\n{conf_matrix}')

# Plot feature importances
feature_importances = random_forest.feature_importances_
indices = np.argsort(feature_importances)[::-1]

plt.bar(range(X.shape[1]), feature_importances[indices])
plt.xticks(range(X.shape[1]), indices, rotation=90)
plt.title('Feature Importances')
plt.xlabel('Feature Index')
plt.ylabel('Importance')
plt.show()
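One way to check the claim that the ensemble improves on a single model is to compare the forest against one decision tree on the same synthetic data. The sketch below uses 5-fold cross-validation via cross_val_score. The single-tree baseline and cv=5 are choices made here for illustration, and the exact scores will vary with the data and the scikit-learn version.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data settings as the Random Forest example.
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_clusters_per_class=2, random_state=42)

models = {
    "single decision tree": DecisionTreeClassifier(random_state=42),
    "random forest (100 trees)": RandomForestClassifier(n_estimators=100, random_state=42),
}

# Cross-validated accuracy depends less on one particular train/test split.
scores = {}
for name, model in models.items():
    fold_scores = cross_val_score(model, X, y, cv=5, scoring="accuracy")
    scores[name] = fold_scores.mean()
    print(f"{name}: mean accuracy {fold_scores.mean():.3f} (std {fold_scores.std():.3f})")
```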

Explanation: